-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add the official implementation of c-TPE #177
Conversation
I verified the performance of c-TPE by using the experiment setup of Fig. 3 (Top Row) in the original paper. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I confirmed the example in README.md works.
I leave some comments.
optuna==4.1.0 To do this experiment, I extracted HPOLib using hpolib-extractor. Please execute the following command: $ cd ~/hpo_benchmarks/hpolib
$ wget http://ml4aad.org/wp-content/uploads/2019/01/fcnet_tabular_benchmarks.tar.gz
$ tar xf fcnet_tabular_benchmarks.tar.gz
$ mv fcnet_tabular_benchmarks/*.hdf5 .
$ rm -r fcnet_tabular_benchmarks/
$ pip install hpolib-extractor
$ python -c "from hpolib_extractor import extract_hpolib; extract_hpolib(data_dir='./', epochs=[100])" Verification Codefrom __future__ import annotations
import os
import pickle
import matplotlib.pyplot as plt
import numpy as np
import optuna
import optunahub
plt.rcParams["font.family"] = "Times New Roman"
plt.rcParams["font.size"] = 18
plt.rcParams["mathtext.fontset"] = "stix" # The setting of math font
plt.rcParams["text.usetex"] = True
# Slice Localization with constraint levels of 0.1, 0.5, and 0.9.
# This information is to reproduce the results of Fig. 3 (Top Row).
n_params_key = "n_params"
runtime_key = "runtime"
benchmark_info = {
0.1: {runtime_key: 119.38172, n_params_key: 8401.0, "oracle": 0.0010303777},
0.5: {runtime_key: 366.13934, n_params_key: 52929.0, "oracle": 0.00028238542},
0.9: {runtime_key: 1113.1345, n_params_key: 229633.0, "oracle": 0.00018452932},
}
data_path = f"{os.environ['HOME']}/hpo_benchmarks/hpolib/slice_localization.pkl"
DATASET = pickle.load(open(data_path, mode="rb"))
class HPOLib:
def __init__(self, quantile: float, seed: int | None = None) -> None:
assert quantile in [0.1, 0.5, 0.9]
self._rng = np.random.RandomState(seed)
self._thresholds = {
runtime_key: benchmark_info[quantile][runtime_key],
n_params_key: benchmark_info[quantile][n_params_key],
}
self._oracle = benchmark_info[quantile]["oracle"]
def reseed(self, seed: int | None = None) -> None:
self._rng = np.random.RandomState(seed)
def __call__(self, trial: optuna.Trial) -> float:
param_indices = [
trial.suggest_categorical("activation_fn_1", list(range(2))),
trial.suggest_categorical("activation_fn_2", list(range(2))),
trial.suggest_int("batch_size", 0, 3),
trial.suggest_int("dropout_1", 0, 2),
trial.suggest_int("dropout_2", 0, 2),
trial.suggest_int("init_lr", 0, 5),
trial.suggest_categorical("lr_schedule", list(range(2))),
trial.suggest_int("n_units_1", 0, 5),
trial.suggest_int("n_units_2", 0, 5),
]
config_id = "".join([str(i) for i in param_indices])
seed = self._rng.randint(4)
result = DATASET[config_id]
loss = result["valid_mse"][seed][100]
trial.set_user_attr(runtime_key, result[runtime_key][seed])
trial.set_user_attr(n_params_key, result[n_params_key])
is_feasible = all(
trial.user_attrs[key] <= self._thresholds[key] for key in [n_params_key, runtime_key]
)
trial.set_user_attr("feasible", is_feasible)
return loss
def constraints_func(self, trial: optuna.trial.FrozenTrial) -> tuple[float, float]:
return [
trial.user_attrs[key] - self._thresholds[key] for key in [n_params_key, runtime_key]
]
def compute_absolute_percentage_loss(self, study: optuna.Study) -> np.ndarray:
trials = study.trials
is_feasible = np.array([t.user_attrs["feasible"] for t in trials])
loss_vals = np.array([t.value for t in trials])
loss_vals[~is_feasible] = np.inf
return (np.minimum.accumulate(loss_vals) - self._oracle) / self._oracle
def collect_results(n_seeds: int, n_trials: int, sampler_name: str) -> dict[float, np.ndarray]:
if sampler_name == "ctpe":
package_name = "samplers/ctpe"
repo_owner = "nabenabe0928"
mod = optunahub.load_local_module(package=package_name, registry_root="./package/")
sampler_cls = mod.cTPESampler
elif sampler_name == "random":
sampler_cls = lambda constraints_func, seed: optuna.samplers.RandomSampler(seed)
elif sampler_name == "nsgaii":
sampler_cls = lambda constraints_func, seed: optuna.samplers.NSGAIISampler(
seed=seed, constraints_func=constraints_func, population_size=8
)
elif sampler_name == "tpe":
sampler_cls = lambda constraints_func, seed: optuna.samplers.TPESampler(
multivariate=True, constraints_func=constraints_func, seed=seed
)
else:
assert False, sampler_name
results = {0.1: [], 0.5: [], 0.9: []}
for q in results:
hpolib = HPOLib(quantile=q)
for seed in range(n_seeds):
print(f"Start with {sampler_name=}, {q=}, and {seed=}.")
hpolib.reseed(seed)
sampler = sampler_cls(constraints_func=hpolib.constraints_func, seed=seed)
study = optuna.create_study(sampler=sampler)
study.optimize(hpolib, n_trials=n_trials)
results[q].append(hpolib.compute_absolute_percentage_loss(study))
return {k: np.asarray(v) for k, v in results.items()}
def visualize(axes: plt.Axes, n_seeds: int, n_trials: int, sampler_name: str) -> plt.Line2D:
results = collect_results(n_seeds, n_trials, sampler_name)
dx = np.arange(n_trials) + 1
color = {
"ctpe": "red",
"nsgaii": "magenta",
"random": "olive",
"tpe": "blue",
}[sampler_name]
for ax, (q, res) in zip(axes, results.items()):
m = np.mean(res, axis=0)
s = np.std(res, axis=0) / np.sqrt(n_seeds)
line, = ax.plot(dx, m, color=color)
ax.fill_between(dx, m - s, m + s, color=color, alpha=0.2)
ax.set_title(f"Quantile: {q:.1f}")
return line
def main(n_seeds: int, n_trials: int) -> None:
fig, axes = plt.subplots(
ncols=3, figsize=(27, 4.5), sharex=True, sharey=True, gridspec_kw={"wspace": 0.05}
)
for ax in axes:
ax.grid(which="minor", color="gray", linestyle=":")
ax.grid(which="major", color="black")
ax.set_yscale("log")
ax.set_xlim(1, n_trials)
ax.set_ylim(0.1, 2000)
# sampler_names = ["ctpe", "tpe", "nsgaii", "random"]
sampler_names = ["ctpe", "tpe", "nsgaii", "random"]
labels = [
{"ctpe": "c-TPE", "tpe": "Optuna TPE", "nsgaii": "CNSGA-II", "random": "Random"}[name]
for name in sampler_names
]
lines = [visualize(axes, n_seeds, n_trials, sampler_name) for sampler_name in sampler_names]
fig.supxlabel("\# of Trials", y=-0.04)
fig.supylabel("Absolute Percentage Loss", x=0.09)
fig.legend(
handles=lines,
labels=labels,
loc="lower center",
ncol=len(labels),
fontsize=24,
bbox_to_anchor=(0.5, -0.25),
)
plt.savefig("slice-localization.png", bbox_inches="tight")
optuna.logging.set_verbosity(optuna.logging.CRITICAL)
main(n_seeds=50, n_trials=200) |
76ce560
to
e131f63
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Published! |
Contributor Agreements
Please read the contributor agreements and if you agree, please click the checkbox below.
Tip
Please follow the Quick TODO list to smoothly merge your PR.
Motivation
This PR is to migrate the official implementation of c-TPE into OptunaHub.
TODO List towards PR Merge
Please remove this section if this PR is not an addition of a new package.
Otherwise, please check the following TODO list:
./template/
to create your package<COPYRIGHT HOLDER>
inLICENSE
of your package with your nameREADME.md
in your package__init__.py
from __future__ import annotations
at the head of any Python files that include typing to support older Python versionsREADME.md
README.md